{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1e2086ae",
   "metadata": {},
   "source": [
    "The three extensions below are optional. For more information, see:\n",
    "- `watermark`:  https://github.com/rasbt/watermark\n",
    "- `pycodestyle_magic`: https://github.com/mattijn/pycodestyle_magic\n",
    "- `nb_black`: https://github.com/dnanhkhoa/nb_black"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "149f964e",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext watermark\n",
    "%watermark -p torch,pytorch_lightning,torchmetrics,matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "625d1d05",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext pycodestyle_magic\n",
    "%flake8_on --ignore W291,W293,E703,E402,E999 --max_line_length=100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f52f4600",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext nb_black"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9dfafe15",
   "metadata": {},
   "source": [
    "<a href=\"https://pytorch.org\"><img src=\"https://raw.githubusercontent.com/pytorch/pytorch/master/docs/source/_static/img/pytorch-logo-dark.svg\" width=\"90\"/></a> &nbsp; &nbsp;&nbsp;&nbsp;<a href=\"https://www.pytorchlightning.ai\"><img src=\"https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/docs/source/_static/images/logo.svg\" width=\"150\"/></a>\n",
    "\n",
    "# TITLE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "501948bd",
   "metadata": {},
   "source": [
    "- DESCRIPTION\n",
    "\n",
    "\n",
    "### References\n",
    "\n",
    "- ???"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7e16245",
   "metadata": {},
   "source": [
    "## General settings and hyperparameters"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3915fdd",
   "metadata": {},
   "source": [
    "- Here, we specify some general hyperparameter values and general settings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f540e31",
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 256\n",
    "NUM_EPOCHS = 10\n",
    "LEARNING_RATE = 0.005\n",
    "NUM_WORKERS = 4"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b472b8d9",
   "metadata": {},
   "source": [
    "- Note that using multiple workers can sometimes cause issues with too many open files in PyTorch for small datasets. If we run into problems with the data loader later, we can try setting `NUM_WORKERS = 0` and reloading the notebook."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a85efca2",
   "metadata": {},
   "source": [
    "## Implementing a neural network using PyTorch Lightning's `LightningModule`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "041bfee2",
   "metadata": {},
   "source": [
    "- In this section, we set up the main model architecture using the `LightningModule` from PyTorch Lightning.\n",
    "- In essence, `LightningModule` is a wrapper around a PyTorch module.\n",
    "- We start by defining our neural network model in pure PyTorch, and then we use it in the `LightningModule` to get all the extra benefits that PyTorch Lightning provides."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cda97b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# UNIQUE MODEL CODE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d8c0e05",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../code_lightningmodule/lightningmodule_classifier_basic.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2961a362",
   "metadata": {},
   "source": [
    "## Setting up the dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d86805b0",
   "metadata": {},
   "source": [
    "- In this section, we are going to set up our dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6fa61b4",
   "metadata": {},
   "source": [
    "### Inspecting the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "087f0762",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../code_dataset/dataset_???_check.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8305a40",
   "metadata": {},
   "source": [
    "### Performance baseline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8c4d897",
   "metadata": {},
   "source": [
    "- Especially for imbalanced datasets, it's helpful to compute a performance baseline.\n",
    "- In classification contexts, a useful baseline is to compute the accuracy for a scenario where the model always predicts the majority class -- we want our model to be better than that!"
   ]
  },
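  {
   "cell_type": "markdown",
   "id": "5c0a71e2",
   "metadata": {},
   "source": [
    "- The majority-class baseline described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the helper name `majority_class_baseline` and the toy label list are hypothetical and are not part of this notebook's dataset or code files.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def majority_class_baseline(labels):\n",
    "    # Count how often each class label occurs\n",
    "    counts = Counter(labels)\n",
    "    # Pick the most frequent label and its count\n",
    "    majority_label, majority_count = counts.most_common(1)[0]\n",
    "    # Accuracy of a model that always predicts the majority label\n",
    "    return majority_label, majority_count / len(labels)\n",
    "\n",
    "\n",
    "# Hypothetical toy labels (assumption, for illustration only)\n",
    "toy_labels = [0, 0, 0, 1, 1, 2, 0, 1, 0, 0]\n",
    "label, baseline_acc = majority_class_baseline(toy_labels)\n",
    "print(f'Majority class: {label}, baseline accuracy: {baseline_acc:.2f}')\n",
    "```\n",
    "\n",
    "- In this toy example, class 0 makes up 6 of the 10 labels, so a trained classifier should reach more than 60% accuracy to be considered useful."
   ]
  },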
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "993ad0d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../code_dataset/performance_baseline.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dab874ca",
   "metadata": {},
   "source": [
    "### A quick visual check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a6f0aa6",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../code_dataset/plot_visual-check_basic.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f40449e",
   "metadata": {},
   "source": [
    "### Setting up a `DataModule`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "734560ae",
   "metadata": {},
   "source": [
    "- There are three main ways we can prepare the dataset for Lightning. We can\n",
    "  1. make the dataset part of the model;\n",
    "  2. set up the data loaders as usual and feed them to the `fit` method of a Lightning `Trainer` -- the `Trainer` is introduced in the following subsection;\n",
    "  3. create a `LightningDataModule`.\n",
    "- Here, we will use approach 3, which is the most organized approach. The `LightningDataModule` consists of several self-explanatory methods, as we can see below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5c8cf5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../code_lightningmodule/datamodule_???_basic.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25132aa4",
   "metadata": {},
   "source": [
    "- Note that the `prepare_data` method is usually used for steps that only need to be executed once, for example, downloading the dataset; the `setup` method defines the dataset loading -- if we run our code in a distributed setting, this will be called on each node / GPU.\n",
    "- Next, let's initialize the `DataModule`; we use a random seed for reproducibility (so that the dataset is shuffled the same way when we re-execute this code):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17132fed",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(1) \n",
    "data_module = DataModule(data_path='./data')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ad8293d",
   "metadata": {},
   "source": [
    "## Training the model using the PyTorch Lightning Trainer class"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25c2b2fb",
   "metadata": {},
   "source": [
    "- Next, we initialize our model.\n",
    "- Also, we define a callback to obtain the model with the best validation set performance after training.\n",
    "- PyTorch Lightning offers [many advanced logging services](https://pytorch-lightning.readthedocs.io/en/latest/extensions/logging.html) like Weights & Biases. However, here, we will keep things simple and use the `CSVLogger`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1acb2f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "pytorch_model = PyTorchModel(\n",
    "    ???\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ca2ccdd",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../code_lightningmodule/logger_csv_acc_basic.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "64857a91",
   "metadata": {},
   "source": [
    "- Now it's time to train our model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61f9f037",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../code_lightningmodule/trainer_nb_basic.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb7a9818",
   "metadata": {},
   "source": [
    "## Evaluating the model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5556da91",
   "metadata": {},
   "source": [
    "- After training, let's plot our training ACC and validation ACC using pandas, which, in turn, uses matplotlib for plotting (PS: you may want to check out [more advanced loggers](https://pytorch-lightning.readthedocs.io/en/latest/extensions/logging.html) later on, which take care of the plotting for us):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b424d39d",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../code_lightningmodule/logger_csv_plot_basic.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee93a525",
   "metadata": {},
   "source": [
    "- The `trainer` automatically saves the model with the best validation accuracy for us, which we can load from the checkpoint via the `ckpt_path='best'` argument; below, we use the `trainer` instance to evaluate the best model on the test set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d02f1378",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.test(model=lightning_model, datamodule=data_module, ckpt_path='best')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6cd101f",
   "metadata": {},
   "source": [
    "## Predicting labels of new data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ffed918e",
   "metadata": {},
   "source": [
    "- We can use the `trainer.predict` method either on a new `DataLoader` (`trainer.predict(dataloaders=...)`) or `DataModule` (`trainer.predict(datamodule=...)`) to apply the model to new data.\n",
    "- Alternatively, we can also manually load the best model from a checkpoint as shown below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e018ee1",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = trainer.checkpoint_callback.best_model_path\n",
    "print(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2548be4",
   "metadata": {},
   "outputs": [],
   "source": [
    "lightning_model = LightningModel.load_from_checkpoint(path, model=pytorch_model)\n",
    "lightning_model.eval();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82eea81a",
   "metadata": {},
   "source": [
    "- For simplicity, we reused our existing `pytorch_model` above. However, we could also reinitialize the `pytorch_model`, and the `.load_from_checkpoint` method would load the corresponding model weights for us from the checkpoint file.\n",
    "- Below is an example of applying the model manually. Here, we pretend that the `test_dataloader` is a new data loader."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5781814",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../code_lightningmodule/datamodule_testloader.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad3600a0",
   "metadata": {},
   "source": [
    "- As an internal check, if the model was loaded correctly, the test accuracy below should be identical to the test accuracy we saw in the previous section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c31bbca9",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_acc = acc.compute()\n",
    "print(f'Test accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)')"
   ]
  },
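  {
   "cell_type": "markdown",
   "id": "7be4d913",
   "metadata": {},
   "source": [
    "- The check above relies on the accuracy being accumulated over all test batches before it is computed. Below is a rough plain-Python sketch of such an update/compute pattern. It assumes that `acc` is a metric object that is updated once per batch; the `RunningAccuracy` class and the mini-batches are hypothetical stand-ins, not the code used in this notebook.\n",
    "\n",
    "```python\n",
    "class RunningAccuracy:\n",
    "    # Plain-Python stand-in for an update/compute-style metric (assumption)\n",
    "    def __init__(self):\n",
    "        self.correct = 0\n",
    "        self.total = 0\n",
    "\n",
    "    def update(self, predicted_labels, true_labels):\n",
    "        # Accumulate raw counts instead of per-batch accuracies\n",
    "        self.correct += sum(p == t for p, t in zip(predicted_labels, true_labels))\n",
    "        self.total += len(true_labels)\n",
    "\n",
    "    def compute(self):\n",
    "        # Accuracy over everything seen so far\n",
    "        return self.correct / self.total\n",
    "\n",
    "\n",
    "# Hypothetical mini-batches of (predictions, targets) with unequal sizes\n",
    "batches = [([0, 1, 1], [0, 1, 0]), ([2, 2], [2, 1]), ([1], [1])]\n",
    "running_acc = RunningAccuracy()\n",
    "for preds, targets in batches:\n",
    "    running_acc.update(preds, targets)\n",
    "print(f'Accuracy: {running_acc.compute():.4f}')\n",
    "```\n",
    "\n",
    "- Summing the correct and total counts gives 4/6 here, whereas averaging the three per-batch accuracies would give a different value because the batches differ in size."
   ]
  },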
  {
   "cell_type": "markdown",
   "id": "9e2265c7",
   "metadata": {},
   "source": [
    "## Inspecting failure cases"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4913044c",
   "metadata": {},
   "source": [
    "- In practice, it is often informative to look at failure cases like wrong predictions for particular training instances as it can give us some insights into the model behavior and dataset.\n",
    "- Inspecting failure cases can sometimes reveal interesting patterns and even highlight dataset and labeling issues."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1b5c4a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# In the case of ???, the class label mapping\n",
    "# ???\n",
    "class_dict = {???}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18ab0294",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../code_lightningmodule/plot_failurecases_basic.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "481b4f08",
   "metadata": {},
   "source": [
    "- In addition to inspecting failure cases visually, it is also informative to look at which classes the model confuses the most via a confusion matrix:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2cb8af8",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../code_lightningmodule/plot_confusion-matrix_basic.py"
   ]
  },
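  {
   "cell_type": "markdown",
   "id": "e36f0b8a",
   "metadata": {},
   "source": [
    "- As a minimal sketch of what a confusion matrix counts, the plain-Python example below tallies (true label, predicted label) pairs. It assumes integer class labels starting at 0; the `confusion_matrix` helper and the toy labels are hypothetical and independent of the plotting code loaded above.\n",
    "\n",
    "```python\n",
    "def confusion_matrix(true_labels, predicted_labels, num_classes):\n",
    "    # Rows index the true class, columns index the predicted class\n",
    "    matrix = [[0] * num_classes for _ in range(num_classes)]\n",
    "    for true, pred in zip(true_labels, predicted_labels):\n",
    "        matrix[true][pred] += 1\n",
    "    return matrix\n",
    "\n",
    "\n",
    "# Hypothetical labels for a 3-class problem (assumption, for illustration only)\n",
    "y_true = [0, 0, 1, 1, 2, 2, 2]\n",
    "y_pred = [0, 1, 1, 1, 2, 0, 2]\n",
    "cm = confusion_matrix(y_true, y_pred, num_classes=3)\n",
    "for row in cm:\n",
    "    print(row)\n",
    "```\n",
    "\n",
    "- The diagonal entries count correct predictions, and the off-diagonal entries count confusions; for instance, `cm[2][0]` is 1 here because one example of class 2 was predicted as class 0."
   ]
  },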
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d48d1ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "%watermark --iversions"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
